import torch
import numpy as np
from torch.utils.data.dataloader import default_collate

from scalelsd.base.csrc import _C

class HAFMencoder(object):
    def __init__(self, dis_th = 10, ang_th = 0):
        self.dis_th = dis_th
        self.ang_th = ang_th

    def __call__(self,annotations):
        targets = []
        metas   = []
        batch_size = annotations['batch_size']
        stride = annotations['stride']
        for batch_id in range(batch_size):

            junctions = annotations['junctions'][batch_id].clone()[:,[1,0]]/float(stride)

            width = annotations['width']//stride
            height = annotations['height']//stride
            edge_indices = annotations['line_map'][batch_id].triu().nonzero()
            
            t, m = self.encoding_single_image(junctions,edge_indices,height,width)
            
            targets.append(t)
            metas.append(m)
        
        return default_collate(targets),metas

    def adjacent_matrix(self, n, edges, device):
        mat = torch.zeros(n+1,n+1,dtype=torch.bool,device=device)
        if edges.size(0)>0:
            mat[edges[:,0], edges[:,1]] += True
            mat[edges[:,1], edges[:,0]] += True
        return mat

    def lines2hafm(self, lines, height, width):
        device = lines.device
        if lines.shape[0] == 0:
            hafm_ang = torch.zeros((3,height,width),device=device)
            hafm_dis = torch.zeros((1,height,width),device=device)
            hafm_mask = torch.zeros((1,height,width),device=device)
            return torch.zeros((3,height,width),device=device), torch.zeros((1,height,width),device=device), torch.zeros((1,height,width),device=device)
        
        lmap, _, _ = _C.encodels(lines,height,width,height,width,lines.size(0))
        dismap = torch.sqrt(lmap[0]**2+lmap[1]**2)[None]
        def _normalize(inp):
            mag = torch.sqrt(inp[0]*inp[0]+inp[1]*inp[1])
            return inp/(mag+1e-6)
        md_map = _normalize(lmap[:2])
        st_map = _normalize(lmap[2:4])
        ed_map = _normalize(lmap[4:])
        st_map = lmap[2:4]
        ed_map = lmap[4:]

        md_ = md_map.reshape(2,-1).t()
        st_ = st_map.reshape(2,-1).t()
        ed_ = ed_map.reshape(2,-1).t()
        Rt = torch.cat(
                (torch.cat((md_[:,None,None,0],md_[:,None,None,1]),dim=2),
                 torch.cat((-md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
        R = torch.cat(
                (torch.cat((md_[:,None,None,0], -md_[:,None,None,1]),dim=2),
                 torch.cat((md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
        #Rtst_ = torch.matmul(Rt, st_[:,:,None]).squeeze(-1).t()
        #Rted_ = torch.matmul(Rt, ed_[:,:,None]).squeeze(-1).t()
        Rtst_ = torch.bmm(Rt, st_[:,:,None]).squeeze(-1).t()
        Rted_ = torch.bmm(Rt, ed_[:,:,None]).squeeze(-1).t()
        swap_mask = (Rtst_[1]<0)*(Rted_[1]>0)
        pos_ = Rtst_.clone()
        neg_ = Rted_.clone()
        temp = pos_[:,swap_mask]
        pos_[:,swap_mask] = neg_[:,swap_mask]
        neg_[:,swap_mask] = temp

        pos_[0] = pos_[0].clamp(min=1e-9)
        pos_[1] = pos_[1].clamp(min=1e-9)
        neg_[0] = neg_[0].clamp(min=1e-9)
        neg_[1] = neg_[1].clamp(max=-1e-9)
        
        mask = (dismap.view(-1)<=self.dis_th).float()

        pos_map = pos_.reshape(-1,height,width)
        neg_map = neg_.reshape(-1,height,width)

        md_angle  = torch.atan2(md_map[1], md_map[0])
        pos_angle = torch.atan2(pos_map[1],pos_map[0])
        neg_angle = torch.atan2(neg_map[1],neg_map[0])

        mask *= (pos_angle.reshape(-1)>self.ang_th*np.pi/2.0)
        mask *= (neg_angle.reshape(-1)<-self.ang_th*np.pi/2.0)

        pos_angle_n = pos_angle/(np.pi/2)
        neg_angle_n = -neg_angle/(np.pi/2)
        md_angle_n  = md_angle/(np.pi*2) + 0.5
        mask    = mask.reshape(height,width)


        hafm_ang = torch.cat((md_angle_n[None],pos_angle_n[None],neg_angle_n[None],),dim=0)
        hafm_dis   = dismap.clamp(max=self.dis_th)/self.dis_th
        mask = mask[None]
        return hafm_ang, hafm_dis, mask

    def encoding_single_image(self, junctions, edge_indices, height, width):
        device = junctions.device

        # jmap = torch.zeros((height,width),device=device)
        # joff = torch.zeros((2,height,width),device=device,dtype=torch.float32)
        jmap = np.zeros((height,width),dtype=np.float32)
        joff = np.zeros((2,height,width),dtype=np.float32)

        dx, dy = np.meshgrid(np.arange(width), np.arange(height))
        # gaussian = np.exp(-(dx**2+dy**2)/2.0/2.0**2)

        if junctions.shape[0] > 0:
            junctions_np = junctions.cpu().numpy()
            xint, yint = junctions_np[:,0].astype(np.int32), junctions_np[:,1].astype(np.int32)
            off_x = junctions_np[:,0] - np.floor(junctions_np[:,0]) - 0.5
            off_y = junctions_np[:,1] - np.floor(junctions_np[:,1]) - 0.5
            
            jmap[yint,xint] = 1#= jmap[yint,xint] + 1
            joff[0,yint,xint] = off_x
            joff[1,yint,xint] = off_y

            lines = junctions[edge_indices].reshape(-1,4)        
            pos_mat = self.adjacent_matrix(junctions.size(0), edge_indices, device)
            labels = torch.ones((lines.shape[0],),device=device)
        else:
            lines = torch.empty((0,4),device=device)
            pos_mat = None
            labels = None
        # for _x,_y in junctions.cpu().numpy():
        #     _map = np.exp(-((dx-_x)**2+(dy-_y)**2)/2.0/8.0**2)
        #     _map /= _map.max()
        #     jmap = np.maximum(jmap,_map)
        # import matplotlib.pyplot as plt
        # import pdb; pdb.set_trace()
        jmap = torch.from_numpy(jmap).to(device)
        joff = torch.from_numpy(joff).to(device)
        hafm_ang, hafm_dis, hafm_mask = self.lines2hafm(lines,height,width)
        

        target = {
            'jloc': jmap[None],
            'joff': joff,
            'md': hafm_ang,
            'dis': hafm_dis,
            'mask': hafm_mask
        }

        meta = {
            'junc': junctions,
            'lines': lines,
            'Lpos': pos_mat,
            'lpre': lines,
            'lpre_label': labels
        }
        return target, meta
    